import torch
from torch import nn
import time
from logging import Logger
from torch.cuda.amp import autocast


# return Shape: 1 * size * size
def subsequent_mask(size):
    """Boolean causal mask of shape (1, size, size).

    Entry [0, i, j] is True when position j is visible from position i
    (i.e. j <= i), so attention cannot look at future tokens.
    """
    shape = (1, size, size)
    # lower triangle (including the diagonal) is the visible region
    return torch.tril(torch.ones(shape)).bool()


class Batch:
    """Holds one batch of source/target token id tensors plus their masks.

    When ``tgt`` is given it is split for teacher forcing: the decoder
    input drops the last token and the labels drop the first.
    """

    def __init__(self, src, tgt=None, pad=2):
        self.src = src
        # hide padding positions in the source from attention
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is None:
            return
        self.tgt = tgt[:, :-1]
        self.tgt_y = tgt[:, 1:]
        self.tgt_mask = self.make_std_mask(self.tgt, pad)
        # count of real (non-pad) label tokens, used to normalize the loss
        self.ntokens = (self.tgt_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Return a mask of shape (batch, tgt_len, tgt_len) hiding both
        padding tokens and future positions."""
        pad_mask = (tgt != pad).unsqueeze(-2)
        causal = subsequent_mask(tgt.size(-1)).type_as(pad_mask.data)
        return pad_mask & causal


class TrainState:
    """Mutable counters describing how far training has progressed."""

    step: int = 0        # optimizer steps taken in the current epoch
    accum_step: int = 0  # gradient-accumulation updates applied
    samples: int = 0     # total examples consumed
    tokens: int = 0      # total tokens processed


class DummyOptimizer(torch.optim.Optimizer):
    """No-op optimizer stand-in (e.g. for evaluation-only runs of run_epoch).

    The previous version never called ``super().__init__``; on torch >= 2.0
    ``Optimizer.step`` is wrapped by hook/profiling machinery that reads
    ``self._optimizer_step_pre_hooks``, so calling ``step()`` raised
    AttributeError. Initializing the base class with a single dummy
    parameter fixes that while keeping ``param_groups[0]["lr"] == 0``
    readable (run_epoch logs it).
    """

    def __init__(self):
        # one throwaway leaf tensor satisfies Optimizer's non-empty
        # parameter requirement; lr=0 matches the original param_groups
        super().__init__([torch.zeros(1, requires_grad=True)], {"lr": 0})

    def step(self):
        """Do nothing — parameters are never updated."""

    def zero_grad(self, set_to_none=False):
        """Do nothing — there are no gradients to clear."""


class DummySchedular:
    """No-op learning-rate scheduler stand-in.

    NOTE(review): the name is a typo of "Scheduler" but is kept so
    existing callers keep working.
    """

    def step(self):
        """Do nothing."""


def run_epoch(data_iter, model, loss_compute, optimizer, scheduler, scaler,
              mode="train", accum_iter=1, train_state=None, logger: Logger = None):
    """Run one epoch of training or evaluation.

    Args:
        data_iter: iterable of Batch objects (src/tgt tensors + masks + ntokens).
        model: module with a ``forward(src, tgt, src_mask, tgt_mask)``.
        loss_compute: callable returning ``(loss, loss_node)`` — the detached
            total loss and the graph-attached loss used for backward.
        optimizer / scheduler / scaler: training objects; only used when
            ``mode`` is "train" or "train+log".
        mode: "train", "train+log", or anything else for evaluation.
        accum_iter: number of batches per optimizer step (gradient accumulation).
        train_state: TrainState to update; a fresh one is created when None.
        logger: optional Logger for periodic progress lines.

    Returns:
        (average loss per token, train_state)
    """
    if train_state is None:
        # fresh state per call — the old ``train_state=TrainState()`` default
        # was a mutable default argument shared across every call
        train_state = TrainState()
    start = time.time()
    total_tokens, total_loss, tokens, n_accum = 0, 0, 0, 0
    for i, batch in enumerate(data_iter):
        with autocast():  # mixed-precision forward pass
            out = model.forward(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
            loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)

        if mode == "train" or mode == "train+log":
            scaler.scale(loss_node).backward()  # backprop the scaled loss
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            # NOTE(review): this fires at i == 0 too, so the very first batch
            # always triggers an optimizer step regardless of accum_iter
            if i % accum_iter == 0:
                scaler.step(optimizer)  # unscale gradients and update params
                scaler.update()
                optimizer.zero_grad(set_to_none=True)  # clear grads so they don't accumulate
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and (mode == "train" or mode == "train+log") and logger is not None:
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            logger.info(("Epoch Step: %6d | Accumulation Step: %3d | " +
                         "Loss: %6.2f | Tokens per Sec: %7.1f | Learning Rate: %6.1e")
                        % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens, train_state


def rate(step, model_size, factor, warmup):
    """Noam learning-rate multiplier from "Attention Is All You Need".

    Grows linearly for ``warmup`` steps, then decays as ``step ** -0.5``,
    all scaled by ``factor * model_size ** -0.5``.
    """
    step = step or 1  # the schedule is undefined at step 0
    warmup_term = step * warmup ** (-1.5)
    decay_term = step ** (-0.5)
    return factor * model_size ** (-0.5) * min(decay_term, warmup_term)


# 正则化：标签平滑，返回 x 和 target 的 KL 散度损失
class LabelSmoothing(nn.Module):
    # size 词典大小
    # padding_idx 填充索引
    # smoothing 平滑值
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")  # KL散度损失， ‘sum’ 表示对所有样本KL散度求和
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing  # 置信度
        self.smoothing = smoothing  # 平滑值
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # 对 true_dist 进行填充，填充值为 smoothing / (size - 2)
        true_dist.fill_(self.smoothing / (self.size - 2))

        # 将 target 中的每个元素作为索引，将 true_dist 中的对应元素置为 confidence
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)

        # 将 target 中的 padding_idx 对应的元素置为 0
        true_dist[:, self.padding_idx] = 0

        # 找到 target 中为 padding_idx 的元素的索引
        mask = torch.nonzero(target.data == self.padding_idx)

        # 如果 mask 不为空，将 true_dist 中对应的元素置为 0（掩码位置处取0）
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist

        # 计算模型输出 x 与平滑处理后的 true_dist 之间的 KL 散度损失
        return self.criterion(x, true_dist.clone().detach())


class SimpleLossCompute:
    """Flattens model output and labels, applies the criterion, and
    normalizes by the token count."""

    def __init__(self, criterion):
        self.criterion = criterion

    def __call__(self, x, y, norm):
        """Return ``(total_loss_detached, normalized_loss_node)``.

        The first element (``.data`` — detached from the autograd graph)
        is for reporting; the second is used for backpropagation.
        """
        flat_logits = x.contiguous().view(-1, x.size(-1))
        flat_targets = y.contiguous().view(-1)
        sloss = self.criterion(flat_logits, flat_targets) / norm
        return sloss.data * norm, sloss


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode greedily: at every step append the highest-scoring token.

    Encodes ``src`` once, then repeatedly decodes the sequence built so
    far and extends it with the argmax of the generator's output at the
    last position, until the sequence reaches ``max_len`` tokens.
    """
    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for _ in range(max_len - 1):
        causal = subsequent_mask(ys.size(1)).type_as(src.data)
        out = model.decode(memory, src_mask, ys, causal)
        # score only the newest position
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        new_token = torch.zeros(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, new_token], dim=1)
    return ys
